{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3b05af3b",
   "metadata": {},
   "source": [
    "(tune-mnist-keras)=\n",
    "\n",
    "# Using Keras & TensorFlow with Tune\n",
    "\n",
    "```{image} /images/tf_keras_logo.jpeg\n",
    ":align: center\n",
    ":alt: Keras & TensorFlow Logo\n",
    ":height: 120px\n",
    ":target: https://www.keras.io\n",
    "```\n",
    "\n",
    "```{contents}\n",
    ":backlinks: none\n",
    ":local: true\n",
    "```\n",
    "\n",
    "## Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19e3c389",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import os\n",
    "\n",
    "from filelock import FileLock\n",
    "from tensorflow.keras.datasets import mnist\n",
    "\n",
    "import ray\n",
    "from ray import tune\n",
    "from ray.tune.schedulers import AsyncHyperBandScheduler\n",
    "from ray.tune.integration.keras import TuneReportCallback\n",
    "\n",
    "\n",
    "def train_mnist(config):\n",
    "    \"\"\"Train a small dense MNIST classifier for a single Tune trial.\n",
    "\n",
    "    Builds a Flatten -> Dense(relu) -> Dropout -> Dense(softmax) Keras model,\n",
    "    compiles it with SGD and fits it for a fixed number of epochs. Metrics are\n",
    "    forwarded to Tune through ``TuneReportCallback``.\n",
    "\n",
    "    Args:\n",
    "        config: Trial hyperparameters. Reads ``config[\"hidden\"]`` (units in the\n",
    "            hidden layer), ``config[\"lr\"]`` (SGD learning rate) and\n",
    "            ``config[\"momentum\"]`` (SGD momentum).\n",
    "    \"\"\"\n",
    "    # Imported inside the function on purpose, see:\n",
    "    # https://github.com/tensorflow/tensorflow/issues/32159\n",
    "    import tensorflow as tf\n",
    "\n",
    "    batch_size = 128\n",
    "    num_classes = 10\n",
    "    epochs = 12\n",
    "\n",
    "    # Hold a file lock so only one process at a time runs `load_data()`\n",
    "    # (presumably to avoid concurrent downloads of the dataset cache).\n",
    "    with FileLock(os.path.expanduser(\"~/.data.lock\")):\n",
    "        (x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
    "    # Rescale pixel values by 1/255.\n",
    "    x_train, x_test = x_train / 255.0, x_test / 255.0\n",
    "    model = tf.keras.models.Sequential(\n",
    "        [\n",
    "            tf.keras.layers.Flatten(input_shape=(28, 28)),\n",
    "            tf.keras.layers.Dense(config[\"hidden\"], activation=\"relu\"),\n",
    "            tf.keras.layers.Dropout(0.2),\n",
    "            tf.keras.layers.Dense(num_classes, activation=\"softmax\"),\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    model.compile(\n",
    "        loss=\"sparse_categorical_crossentropy\",\n",
    "        # `learning_rate` is the supported keyword; the old `lr` alias is\n",
    "        # deprecated and rejected by newer Keras optimizers.\n",
    "        optimizer=tf.keras.optimizers.SGD(\n",
    "            learning_rate=config[\"lr\"], momentum=config[\"momentum\"]\n",
    "        ),\n",
    "        metrics=[\"accuracy\"],\n",
    "    )\n",
    "\n",
    "    model.fit(\n",
    "        x_train,\n",
    "        y_train,\n",
    "        batch_size=batch_size,\n",
    "        epochs=epochs,\n",
    "        verbose=0,\n",
    "        validation_data=(x_test, y_test),\n",
    "        # Report the Keras `accuracy` metric to Tune as `mean_accuracy`.\n",
    "        # NOTE(review): the validation metric would be `val_accuracy`; confirm\n",
    "        # that tuning on training accuracy is intended.\n",
    "        callbacks=[TuneReportCallback({\"mean_accuracy\": \"accuracy\"})],\n",
    "    )\n",
    "\n",
    "\n",
    "def tune_mnist(num_training_iterations):\n",
    "    \"\"\"Run the MNIST hyperparameter search and print the best configuration.\n",
    "\n",
    "    Starts 10 samples of ``train_mnist`` (2 CPUs, 0 GPUs each) under an\n",
    "    ``AsyncHyperBandScheduler``, maximizing ``mean_accuracy``.\n",
    "\n",
    "    Args:\n",
    "        num_training_iterations: Value used for the ``training_iteration`` stop\n",
    "            criterion passed to ``tune.run``.\n",
    "    \"\"\"\n",
    "    # NOTE(review): train_mnist fits for 12 epochs; if it reports once per epoch,\n",
    "    # trials never reach grace_period=20 -- confirm these settings are intended.\n",
    "    asha_scheduler = AsyncHyperBandScheduler(\n",
    "        time_attr=\"training_iteration\",\n",
    "        max_t=400,\n",
    "        grace_period=20,\n",
    "    )\n",
    "\n",
    "    # A trial stops as soon as either criterion is met.\n",
    "    stop_criteria = {\n",
    "        \"mean_accuracy\": 0.99,\n",
    "        \"training_iteration\": num_training_iterations,\n",
    "    }\n",
    "\n",
    "    # Hyperparameters handed to each train_mnist trial.\n",
    "    search_space = {\n",
    "        \"threads\": 2,\n",
    "        \"lr\": tune.uniform(0.001, 0.1),\n",
    "        \"momentum\": tune.uniform(0.1, 0.9),\n",
    "        \"hidden\": tune.randint(32, 512),\n",
    "    }\n",
    "\n",
    "    result = tune.run(\n",
    "        train_mnist,\n",
    "        name=\"exp\",\n",
    "        scheduler=asha_scheduler,\n",
    "        metric=\"mean_accuracy\",\n",
    "        mode=\"max\",\n",
    "        stop=stop_criteria,\n",
    "        num_samples=10,\n",
    "        resources_per_trial={\"cpu\": 2, \"gpu\": 0},\n",
    "        config=search_space,\n",
    "    )\n",
    "    best = result.best_config\n",
    "    print(\"Best hyperparameters found were: \", best)\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Script entry point. `parse_known_args` is used, so unrecognized\n",
    "    # command-line arguments are ignored instead of raising an error.\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument(\n",
    "        \"--smoke-test\",\n",
    "        action=\"store_true\",\n",
    "        help=\"Finish quickly for testing\",\n",
    "    )\n",
    "    parser.add_argument(\n",
    "        \"--server-address\",\n",
    "        type=str,\n",
    "        default=None,\n",
    "        required=False,\n",
    "        help=\"The address of server to connect to if using Ray Client.\",\n",
    "    )\n",
    "    args, _ = parser.parse_known_args()\n",
    "\n",
    "    if args.smoke_test:\n",
    "        # Smoke test: small local Ray instance and a short run.\n",
    "        ray.init(num_cpus=4)\n",
    "        num_iterations = 5\n",
    "    else:\n",
    "        if args.server_address:\n",
    "            # Connect to the given address through Ray Client.\n",
    "            ray.init(f\"ray://{args.server_address}\")\n",
    "        num_iterations = 300\n",
    "\n",
    "    tune_mnist(num_training_iterations=num_iterations)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7e46189",
   "metadata": {},
   "source": [
    "## More Keras and TensorFlow Examples\n",
    "\n",
    "- {doc}`/tune/examples/includes/pbt_memnn_example`: Example of training a Memory NN on bAbI with Keras using PBT.\n",
    "- {doc}`/tune/examples/includes/tf_mnist_example`: Converts the Advanced TF 2.0 MNIST example to use Tune\n",
    "  with the `Trainable` API. This example uses `tf.function`.\n",
    "  Original code from tensorflow: https://www.tensorflow.org/tutorials/quickstart/advanced\n",
    "- {doc}`/tune/examples/includes/pbt_tune_cifar10_with_keras`:\n",
    "  A contributed example of tuning a Keras model on CIFAR10 with the PopulationBasedTraining scheduler.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "orphan": true
 },
 "nbformat": 4,
 "nbformat_minor": 5
}